{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "0",
   "metadata": {},
   "source": [
    "This example requires the following dependencies to be installed:\n",
    "pip install lightly"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1",
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip install lightly"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2",
   "metadata": {},
   "source": [
    "Note: The model and training settings do not follow the reference settings\n",
    "from the paper. The settings are chosen such that the example can easily be\n",
    "run on a small dataset with a single GPU."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pytorch_lightning as pl\n",
    "import torch\n",
    "import torchvision\n",
    "from torch import nn"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4",
   "metadata": {},
   "outputs": [],
   "source": [
    "from lightly.loss import BarlowTwinsLoss\n",
    "from lightly.models.modules import BarlowTwinsProjectionHead\n",
    "from lightly.transforms.byol_transform import (\n",
    "    BYOLTransform,\n",
    "    BYOLView1Transform,\n",
    "    BYOLView2Transform,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5",
   "metadata": {},
   "outputs": [],
   "source": [
    "class BarlowTwins(pl.LightningModule):\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        resnet = torchvision.models.resnet18()\n",
    "        self.backbone = nn.Sequential(*list(resnet.children())[:-1])\n",
    "        self.projection_head = BarlowTwinsProjectionHead(512, 2048, 2048)\n",
    "\n",
    "        # enable gather_distributed to gather features from all gpus\n",
    "        # before calculating the loss\n",
    "        self.criterion = BarlowTwinsLoss(gather_distributed=True)\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = self.backbone(x).flatten(start_dim=1)\n",
    "        z = self.projection_head(x)\n",
    "        return z\n",
    "\n",
    "    def training_step(self, batch, batch_index):\n",
    "        (x0, x1) = batch[0]\n",
    "        z0 = self.forward(x0)\n",
    "        z1 = self.forward(x1)\n",
    "        loss = self.criterion(z0, z1)\n",
    "        return loss\n",
    "\n",
    "    def configure_optimizers(self):\n",
    "        optim = torch.optim.SGD(self.parameters(), lr=0.06)\n",
    "        return optim"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6",
   "metadata": {},
   "outputs": [],
   "source": [
    "model = BarlowTwins()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# BarlowTwins uses BYOL augmentations.\n",
    "# We disable resizing and gaussian blur for cifar10.\n",
    "transform = BYOLTransform(\n",
    "    view_1_transform=BYOLView1Transform(input_size=32, gaussian_blur=0.0),\n",
    "    view_2_transform=BYOLView2Transform(input_size=32, gaussian_blur=0.0),\n",
    ")\n",
    "dataset = torchvision.datasets.CIFAR10(\n",
    "    \"datasets/cifar10\", download=True, transform=transform\n",
    ")\n",
    "# or create a dataset from a folder containing images or videos:\n",
    "# dataset = LightlyDataset(\"path/to/folder\", transform=transform)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8",
   "metadata": {},
   "outputs": [],
   "source": [
    "dataloader = torch.utils.data.DataLoader(\n",
    "    dataset,\n",
    "    batch_size=256,\n",
    "    shuffle=True,\n",
    "    drop_last=True,\n",
    "    num_workers=8,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Train with DDP and use Synchronized Batch Norm for a more accurate batch norm\n",
    "# calculation. Distributed sampling is also enabled with replace_sampler_ddp=True.\n",
    "trainer = pl.Trainer(\n",
    "    max_epochs=10,\n",
    "    devices=\"auto\",\n",
    "    accelerator=\"gpu\",\n",
    "    strategy=\"ddp\",\n",
    "    sync_batchnorm=True,\n",
    "    use_distributed_sampler=True,  # or replace_sampler_ddp=True for PyTorch Lightning <2.0\n",
    ")\n",
    "trainer.fit(model=model, train_dataloaders=dataloader)"
   ]
  }
 ],
 "metadata": {
  "jupytext": {
   "cell_metadata_filter": "-all",
   "main_language": "python",
   "notebook_metadata_filter": "-all"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
